import math
from typing import Callable, Optional, Union

from omegaconf import DictConfig
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from einops.layers.torch import Rearrange

from .abstract import Elasticity, Plasticity
from .utils import get_nonlinearity, get_norm, init_weight
from .loralib import (
    mark_only_lora_as_trainable,
    mark_lora_and_phys_as_trainable,
    lora_state_dict,
    replace_with_linear_lora,
    init_linear_lora,
    LinearLoRA
)


class MLPBlock(nn.Module):
    """Fully connected layer followed by a norm and a nonlinearity.

    Args:
        in_planes: Number of input features of the linear layer.
        out_planes: Number of output features of the linear layer.
        no_bias: When true, the linear layer has no bias and the norm is
            requested without affine parameters.
        norm: Norm identifier handed to ``get_norm``. ``'wn'`` additionally
            wraps the linear layer in weight normalisation; any other
            non-``None`` value disables the linear bias.
        nonlinearity: Identifier handed to ``get_nonlinearity``.
    """

    def __init__(self, in_planes: int, out_planes: int, no_bias: bool,
                 norm: Optional[str], nonlinearity: Optional[str]) -> None:
        super().__init__()

        use_bias = not no_bias
        if norm == 'wn':
            linear = nn.utils.weight_norm(nn.Linear(in_planes, out_planes, use_bias))
        else:
            # A separate norm layer follows, so the linear bias is only kept
            # when no norm was requested at all.
            linear = nn.Linear(in_planes, out_planes, bias=use_bias and norm is None)

        self.fc = linear
        # Both helpers are called unconditionally in forward(); presumably they
        # return identity-like modules for ``None`` -- TODO confirm in .utils.
        self.norm = get_norm(norm, out_planes, dim=1, affine=use_bias)
        self.nonlinearity = get_nonlinearity(nonlinearity)

    def forward(self, x: Tensor) -> Tensor:
        """Apply linear -> norm -> nonlinearity to ``x``."""
        return self.nonlinearity(self.norm(self.fc(x)))





################### LXY




class MultiScaleBlock(nn.Module):
    """Frozen base linear layer plus three gated additive branches.

    The block reuses the ``fc``, ``norm`` and ``nonlinearity`` modules of an
    existing :class:`MLPBlock`. Three extra :class:`MLPBlock` branches
    (``elastic_lora``, ``plastic_lora``, ``visco_lora``) see the same input
    as ``fc``; their outputs are scaled by a per-sample gate predicted from
    ``invariants`` and added to the base output before the shared norm and
    nonlinearity are applied.

    Args:
        base_block: Block whose modules are shared (not copied). Its ``fc``
            parameters are frozen here.
        no_bias: ``no_bias`` setting for the three branches.
        norm: ``norm`` setting for the three branches.
        nonlinearity: ``nonlinearity`` setting for the three branches.
        num_invariants: Feature size of the ``invariants`` tensor consumed by
            the gate predictor.

    The original constructor read these branch settings from an undefined
    global ``cfg`` and therefore raised ``NameError``; they are now explicit
    keyword arguments with defaults, so ``MultiScaleBlock(base_block)`` works.
    """

    def __init__(
            self,
            base_block: MLPBlock,
            no_bias: bool = False,
            norm: Optional[str] = None,
            nonlinearity: Optional[str] = None,
            num_invariants: int = 13
    ):
        super().__init__()
        self.fc = base_block.fc
        self.norm = base_block.norm
        self.nonlinearity = base_block.nonlinearity

        in_features = self.fc.in_features
        out_features = self.fc.out_features

        def make_branch() -> MLPBlock:
            # Every branch maps the block input to the block's output width.
            return MLPBlock(
                in_planes=in_features,
                out_planes=out_features,
                no_bias=no_bias,
                norm=norm,
                nonlinearity=nonlinearity
            )

        self.elastic_lora = make_branch()
        self.plastic_lora = make_branch()
        self.visco_lora = make_branch()

        # One sigmoid gate in (0, 1) per branch and per sample.
        self.mask_predictor = nn.Sequential(
            nn.Linear(num_invariants, 8),
            nn.ReLU(),
            nn.Linear(8, 3),
            nn.Sigmoid()
        )

        # Only the branches and the gate are meant to be trained.
        for param in self.fc.parameters():
            param.requires_grad = False

    def forward(self, x: Tensor, invariants: Tensor) -> Tensor:
        """Return ``nonlinearity(norm(fc(x) + gated branch outputs))``.

        Args:
            x: Input features; last dimension must equal ``fc.in_features``.
            invariants: Gate input; last dimension must equal
                ``num_invariants``. The gate is sliced as ``mask[:, k:k+1]``,
                so a 2-D ``[B, num_invariants]`` tensor is expected.
        """
        x_base = self.fc(x)

        mask = self.mask_predictor(invariants)  # [B,3]
        delta_elastic = self.elastic_lora(x) * mask[:, 0:1]
        delta_plastic = self.plastic_lora(x) * mask[:, 1:2]
        delta_visco = self.visco_lora(x) * mask[:, 2:3]

        x_out = x_base + delta_elastic + delta_plastic + delta_visco

        x_out = self.norm(x_out)
        x_out = self.nonlinearity(x_out)
        return x_out


################### LXY








class MetaElasticity(Elasticity):
    """Common base for the learned elasticity models in this module.

    Provides ``flatten`` / ``unflatten`` helpers that convert between
    ``(b, dim, dim)`` matrices and ``(b, dim * dim)`` vectors, and stores the
    ``normalize_input`` flag from the config. ``self.dim`` is read here and is
    expected to be set by ``Elasticity.__init__``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        matrix_axes = dict(d1=self.dim, d2=self.dim)
        self.flatten = Rearrange('b d1 d2 -> b (d1 d2)', **matrix_axes)
        self.unflatten = Rearrange('b (d1 d2) -> b d1 d2', **matrix_axes)

        self.normalize_input: bool = cfg.normalize_input

    def forward(self, F: Tensor) -> Tensor:
        """Subclasses implement the mapping from ``F``; the base always raises."""
        raise NotImplementedError


class PlainMetaElasticity(MetaElasticity):
    """MLP acting directly on the flattened matrix ``F``.

    The network output is reshaped to ``dim x dim`` (``P``) and the model
    returns ``P @ F^T``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        n_entries = self.dim * self.dim
        # Input width followed by every hidden width from the config.
        widths = [n_entries] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(widths[-1], n_entries, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``unflatten(MLP(F or F - I)) @ F^T``."""
        if self.normalize_input:
            eye = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
            h = self.flatten(F - eye)
        else:
            h = self.flatten(F)

        for block in self.layers:
            h = block(h)

        P = self.unflatten(self.final_layer(h))
        return torch.matmul(P, self.transpose(F))


class PolarMetaElasticity(MetaElasticity):
    """MLP acting on the symmetric factor built from the SVD of ``F``.

    With ``U, sigma, Vh = svd(F)`` the model forms ``R = U @ Vh`` and
    ``S = Vh^T @ diag(sigma) @ Vh``, feeds ``S`` (optionally ``S - I``) to the
    MLP, symmetrises the reshaped output and returns ``R @ sym @ F^T``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        n_entries = self.dim * self.dim
        widths = [n_entries] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(widths[-1], n_entries, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``R @ sym(MLP(S)) @ F^T`` for the factors described on the class."""
        U, sigma, Vh = self.svd(F)

        R = torch.matmul(U, Vh)
        scaled_Vh = torch.matmul(self.transpose(Vh), torch.diag_embed(sigma))
        S = torch.matmul(scaled_Vh, Vh)

        if self.normalize_input:
            eye = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
            h = self.flatten(S - eye)
        else:
            h = self.flatten(S)

        for block in self.layers:
            h = block(h)
        h = self.unflatten(self.final_layer(h))

        # Average with the transpose so the network output is symmetric.
        sym = 0.5 * (self.transpose(h) + h)
        P = torch.matmul(R, sym)
        return torch.matmul(P, self.transpose(F))


class InvariantMetaElasticity(MetaElasticity):
    """MLP acting on three scalar features of ``F``.

    Features per sample: the sum of the singular values, the trace of
    ``F^T F`` and ``det(F)``. The reshaped network output is symmetrised,
    left-multiplied by ``R = U @ Vh`` and right-multiplied by ``F^T``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        # Three scalar input features, then the configured hidden widths.
        widths = [3] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        self.final_layer = MLPBlock(widths[-1], self.dim * self.dim, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``R @ sym(MLP([I1, I2, I3])) @ F^T``."""
        U, sigma, Vh = self.svd(F)
        R = torch.matmul(U, Vh)

        Ft = self.transpose(F)

        I1 = sigma.sum(dim=1)
        I2 = torch.diagonal(torch.matmul(Ft, F), dim1=1, dim2=2).sum(dim=1)
        I3 = torch.linalg.det(F)
        if self.normalize_input:
            # NOTE(review): I1 is shifted by 3.0 but the trace feature only by
            # 1.0; kept exactly as in the existing model -- confirm intent.
            I1 = I1 - 3.0
            I2 = I2 - 1.0
            I3 = I3 - 1.0

        h = torch.stack([I1, I2, I3], dim=1)
        for block in self.layers:
            h = block(h)
        h = self.unflatten(self.final_layer(h))

        sym = 0.5 * (self.transpose(h) + h)
        P = torch.matmul(R, sym)
        return torch.matmul(P, self.transpose(F))


class InvariantFullMetaElasticity(MetaElasticity):
    """MLP acting on singular values, the flattened ``F^T F`` and ``det(F)``.

    The feature vector has ``dim + dim * dim + 1`` entries. The reshaped
    network output is symmetrised and the model returns ``R @ sym @ F^T`` with
    ``R = U @ Vh`` taken from the SVD of ``F``. LoRA helpers delegate to the
    functions imported from ``.loralib``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        n_entries = self.dim * self.dim
        widths = [self.dim + n_entries + 1] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(widths[-1], n_entries, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def init_lora_layers(self, r: int, lora_alpha: int = 1):
        """Run ``replace_with_linear_lora`` on this module for ``nn.Linear`` layers."""
        replace_with_linear_lora(self, nn.Linear, r, lora_alpha)
        model_name = type(self).__name__
        print(f"Initialized LoRA layers in {model_name} with r={r} and alpha={lora_alpha}.")

    def freeze_all_except_lora(self):
        """Delegate to ``mark_only_lora_as_trainable``."""
        mark_only_lora_as_trainable(self)

    def lora_state_dict(self, bias: str = 'none'):
        """Return the result of the module-level ``lora_state_dict`` helper."""
        return lora_state_dict(self, bias)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``R @ sym(MLP(features)) @ F^T``."""
        eye = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
        U, sigma, Vh = self.svd(F)
        R = torch.matmul(U, Vh)

        Ft = self.transpose(F)
        FtF = torch.matmul(Ft, F)
        det = torch.linalg.det(F).unsqueeze(dim=1)

        if self.normalize_input:
            # Shift every feature so that it vanishes at F = I.
            features = [sigma - 1.0, self.flatten(FtF - eye), det - 1.0]
        else:
            features = [sigma, self.flatten(FtF), det]

        h = torch.cat(features, dim=1)
        for block in self.layers:
            h = block(h)
        h = self.unflatten(self.final_layer(h))

        sym = 0.5 * (self.transpose(h) + h)
        P = torch.matmul(R, sym)
        return torch.matmul(P, Ft)



#################### LXY

class InvariantFullMetaPhyElasticity_old(MetaElasticity):
    """Single-branch invariant network kept under an ``_old`` name.

    Input features are the singular values of ``F``, the flattened ``F^T F``
    and ``det(F)`` (``dim + dim * dim + 1`` values). The output is reshaped,
    symmetrised and returned as ``R @ sym @ F^T`` where ``R = U @ Vh``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        out_width = self.dim * self.dim
        in_width = self.dim + out_width + 1

        self.layers = nn.ModuleList()
        for hidden_width in cfg.layer_widths:
            block = MLPBlock(in_width, hidden_width, cfg.no_bias, cfg.norm, cfg.nonlinearity)
            self.layers.append(block)
            in_width = hidden_width

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(in_width, out_width, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def init_lora_layers(self, r: int, lora_alpha: int = 1):
        """Run ``replace_with_linear_lora`` on this module for ``nn.Linear`` layers."""
        replace_with_linear_lora(self, nn.Linear, r, lora_alpha)
        model_name = type(self).__name__
        print(f"Initialized LoRA layers in {model_name} with r={r} and alpha={lora_alpha}.")

    def freeze_all_except_lora(self):
        """Delegate to ``mark_only_lora_as_trainable``."""
        mark_only_lora_as_trainable(self)

    def lora_state_dict(self, bias: str = 'none'):
        """Return the result of the module-level ``lora_state_dict`` helper."""
        return lora_state_dict(self, bias)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``R @ sym(MLP(features)) @ F^T``."""
        eye = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
        U, sigma, Vh = self.svd(F)
        R = torch.matmul(U, Vh)

        Ft = self.transpose(F)
        FtF = torch.matmul(Ft, F)
        det = torch.linalg.det(F).unsqueeze(dim=1)

        # With normalisation every feature is shifted to vanish at F = I.
        shifted = self.normalize_input
        stretch_feat = sigma - 1.0 if shifted else sigma
        gram_feat = self.flatten(FtF - eye) if shifted else self.flatten(FtF)
        det_feat = det - 1.0 if shifted else det

        h = torch.cat([stretch_feat, gram_feat, det_feat], dim=1)
        for block in self.layers:
            h = block(h)
        h = self.final_layer(h)
        h = self.unflatten(h)

        sym = 0.5 * (self.transpose(h) + h)
        return torch.matmul(torch.matmul(R, sym), Ft)


class InvariantFullMetaPhyElasticity(MetaElasticity):
    """Invariant-based elasticity network with a "main" and a "phys" branch.

    Both branches are MLPs over the same ``dim + dim * dim + 1`` features
    (singular values of ``F``, flattened ``F^T F``, ``det(F)``). Parameters of
    the phys branch are recognisable by the substring ``"phys_"`` in their
    qualified name (``phys_layers.phys_layer_<i>...`` and ``phys_final...``);
    LoRA parameters are recognised by the substring ``"lora_"``. All freezing
    and grouping helpers below rely on those two substrings.

    In ``forward`` the phys branch is evaluated but its output is multiplied
    by ``0.0`` before being added to the main branch.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        # main
        self.main_layers = nn.ModuleList()
        width = self.dim + self.dim * self.dim + 1
        for next_width in cfg.layer_widths:
            self.main_layers.append(MLPBlock(width, next_width, cfg.no_bias, cfg.norm, cfg.nonlinearity))
            width = next_width
        self.main_final = MLPBlock(width, self.dim * self.dim, cfg.no_bias, None, None)

        # phys
        self.phys_layers = nn.ModuleList()
        phys_width = self.dim + self.dim * self.dim + 1
        for i, next_width in enumerate(cfg.phys_layer_widths):
            layer = MLPBlock(phys_width, next_width, cfg.no_bias, cfg.norm, cfg.nonlinearity)
            layer._layer_type = "phys"
            # Registered under an explicit name so that parameter names carry
            # the "phys_" marker at this level too.
            self.phys_layers.add_module(f"phys_layer_{i}", layer)
            phys_width = next_width
        self.phys_final = MLPBlock(phys_width, self.dim * self.dim, cfg.no_bias, None, None)
        self.phys_final._layer_type = "phys"

        # NOTE(review): this loop used to skip modules whose ``_get_name()``
        # started with "phys_". ``_get_name()`` returns the class name, so the
        # filter never matched and every module was initialised. That effective
        # behaviour is kept; switch to ``named_modules()`` if the phys branch
        # is really meant to keep PyTorch's default initialisation.
        for m in self.modules():
            init_weight(m)

    @staticmethod
    def _is_phys(name: str) -> bool:
        """Tell whether a qualified parameter/module name is in the phys branch."""
        return "phys_" in name

    @staticmethod
    def _is_lora(name: str) -> bool:
        """Tell whether a qualified parameter name belongs to a LoRA adapter."""
        return "lora_" in name

    def _set_trainable(self, is_trainable: Callable[[str], bool]) -> None:
        """Set ``requires_grad`` of every parameter to ``is_trainable(name)``."""
        for name, param in self.named_parameters():
            param.requires_grad = is_trainable(name)

    def _init_phys_params(self):
        """Xavier-initialise the ``nn.Linear`` layers of the phys branch.

        Biases are set to ``0.1``. Layers are selected by qualified module
        name; the earlier check on ``_get_name()`` compared against the class
        name and therefore never selected anything.
        """
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear) and self._is_phys(name):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.1)

    def init_lora_layers(self, r: int, lora_alpha: int = 1):
        """Run ``replace_with_linear_lora`` on the whole module.

        The call is made on ``self``, so ``nn.Linear`` layers of both the main
        and the phys branch are handed to the helper.
        """
        replace_with_linear_lora(self, nn.Linear, r, lora_alpha)
        print(f"Initialized LoRA layers in {type(self).__name__} with r={r} and alpha={lora_alpha}.")

    def load_pretrained(self, pretrained_source: Union[str, dict]):
        """Copy matching main-branch entries from a checkpoint into this model.

        Args:
            pretrained_source: Either a path understood by ``torch.load`` or an
                already loaded state dict.

        Raises:
            TypeError: If ``pretrained_source`` is neither ``str`` nor ``dict``.

        Only entries whose key exists in this model, does not contain
        ``"phys_"`` and whose shape matches are copied.
        """
        if isinstance(pretrained_source, str):
            # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
            pretrained_dict = torch.load(pretrained_source, map_location='cpu')
        elif isinstance(pretrained_source, dict):
            pretrained_dict = pretrained_source
        else:
            raise TypeError(f"Unsupported pretrained source type: {type(pretrained_source)}")

        model_dict = self.state_dict()
        filtered_dict = {k: v for k, v in pretrained_dict.items()
                         if k in model_dict
                         and not self._is_phys(k)
                         and v.shape == model_dict[k].shape}

        model_dict.update(filtered_dict)
        self.load_state_dict(model_dict, strict=False)
        print(f"Loaded {len(filtered_dict)}/{len(pretrained_dict)} main branch params")

    def forward(self, F: Tensor) -> Tensor:
        """Return ``R @ sym(main(features) + 0.0 * phys(features)) @ F^T``."""
        I = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
        U, sigma, Vh = self.svd(F)
        R = torch.matmul(U, Vh)

        Ft = self.transpose(F)
        FtF = torch.matmul(Ft, F)

        if self.normalize_input:
            I1 = sigma - 1.0
            I2 = self.flatten(FtF - I)
            I3 = torch.linalg.det(F).unsqueeze(dim=1) - 1.0
        else:
            I1 = sigma
            I2 = self.flatten(FtF)
            I3 = torch.linalg.det(F).unsqueeze(dim=1)

        x_cat = torch.cat([I1, I2, I3], dim=1)
        x_main = x_cat
        # Same values as x_cat (detached copy plus a zero-valued difference);
        # the difference term is explicitly flagged as requiring grad.
        x_phys = x_cat.detach() + (x_cat - x_cat.detach()).requires_grad_(True)

        for layer in self.main_layers:
            x_main = layer(x_main)
        x_main = self.main_final(x_main)

        for layer in self.phys_layers:
            x_phys = layer(x_phys)
        x_phys = self.phys_final(x_phys)

        # The phys branch is currently switched off: its output is scaled by 0.
        x = x_main + x_phys * 0.0

        x = self.unflatten(x)
        x = 0.5 * (self.transpose(x) + x)
        P = torch.matmul(R, x)
        cauchy = torch.matmul(P, Ft)

        return cauchy

    def get_parameter_groups(self):
        """Partition the parameters by name.

        Returns:
            dict with three lists: ``"lora_params"`` (LoRA, not phys),
            ``"phys_params"`` (anything in the phys branch) and
            ``"frozen_main"`` (neither LoRA nor phys).
        """
        named = list(self.named_parameters())
        return {
            "lora_params": [
                p for n, p in named
                if self._is_lora(n) and not self._is_phys(n)
            ],

            "phys_params": [
                p for n, p in named
                if self._is_phys(n)
            ],

            "frozen_main": [
                p for n, p in named
                if not self._is_phys(n) and not self._is_lora(n)
            ]
        }

    def setup_optimizers(self, lr_lora=1e-4, lr_phys=1e-3):
        """Freeze the plain main-branch weights and build two optimizers.

        Args:
            lr_lora: Learning rate of the AdamW optimizer over ``lora_params``.
            lr_phys: Learning rate of the Adam optimizer over ``phys_params``.

        Returns:
            Tuple ``(opt_lora, opt_phys)``.
        """
        groups = self.get_parameter_groups()

        # freeze main
        for p in groups["frozen_main"]:
            p.requires_grad = False

        # LoRA optimizer
        opt_lora = torch.optim.AdamW(
            groups["lora_params"],
            lr=lr_lora
        )

        # phys optimizer
        opt_phys = torch.optim.Adam(
            groups["phys_params"],
            lr=lr_phys,
            weight_decay=1e-4
        )

        return opt_lora, opt_phys

    def mark_trainable(self, bias):
        """Delegate to ``mark_lora_and_phys_as_trainable``."""
        mark_lora_and_phys_as_trainable(self, bias)

    def freeze_all_except_lora(self):
        """Train only LoRA parameters outside the phys branch.

        The previous body called ``_get_name()`` on parameters and a
        non-existent ``self.phys_params()``, so it raised ``AttributeError``.
        """
        self._set_trainable(lambda n: self._is_lora(n) and not self._is_phys(n))

    def freeze_main(self):
        """Train only the phys branch; everything else is frozen."""
        self._set_trainable(self._is_phys)

    def freeze_all_except_lora_and_phys(self):
        """Train LoRA parameters and the phys branch; freeze the rest.

        NOTE(review): the earlier (crashing) body had the LoRA / non-LoRA
        flags inverted relative to the method name; the name is followed here.
        """
        self._set_trainable(lambda n: self._is_lora(n) or self._is_phys(n))

    def _zero_init_phys_block(self, block: MLPBlock):
        """Zero the linear layer and any BatchNorm1d/LayerNorm affine terms of ``block``."""
        nn.init.zeros_(block.fc.weight)
        if block.fc.bias is not None:
            nn.init.zeros_(block.fc.bias)

        norm = block.norm
        if isinstance(norm, (nn.BatchNorm1d, nn.LayerNorm)):
            # Non-affine norms carry ``None`` for both attributes.
            if norm.weight is not None:
                nn.init.zeros_(norm.weight)
            if norm.bias is not None:
                nn.init.zeros_(norm.bias)

    def lora_state_dict(self, bias: str = 'none'):
        """Return the result of the module-level ``lora_state_dict`` helper."""
        return lora_state_dict(self, bias)


#################### LXY






class SVDMetaElasticity(MetaElasticity):
    """MLP that maps singular values of ``F`` to new diagonal entries.

    With ``U, sigma, Vh = svd(F)`` the model returns
    ``U @ diag(MLP(sigma)) @ Vh @ F^T``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        # One input per singular value, then the configured hidden widths.
        widths = [self.dim] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity, one value per singular value.
        self.final_layer = MLPBlock(widths[-1], self.dim, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``U @ diag(MLP(sigma)) @ Vh @ F^T``."""
        Ft = self.transpose(F)
        U, sigma, Vh = self.svd(F)

        h = sigma - 1.0 if self.normalize_input else sigma
        for block in self.layers:
            h = block(h)
        h = self.final_layer(h)

        scaled_U = torch.matmul(U, torch.diag_embed(h))
        P = torch.matmul(scaled_U, Vh)
        return torch.matmul(P, Ft)


# http://viterbi-web.usc.edu/~jbarbic/isotropicMaterialEditor/XuSinZhuBarbic-Siggraph2015.pdf
class SplineMetaElasticity(MetaElasticity):
    """Elasticity model built from three learnable 1-D piecewise-cubic curves.

    The curves ``f``, ``g`` and ``h`` are evaluated on the singular values of
    ``F``, on pairwise products of singular values and on the product of all
    three, respectively. Each curve is described by one learnable value per
    knot (``yk_f``, ``yk_g``, ``yk_h``) on the shared knot vector ``xk``.
    ``forward`` indexes ``sigma[:, 0..2]``, i.e. it expects three singular
    values per sample.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        self.num_side_points: int = cfg.num_side_points
        self.xk_min: float = 0.0
        self.xk_max: float = cfg.xk_max
        self.yk_min: float = -cfg.yk_max
        self.yk_max: float = cfg.yk_max

        # Knots: num_side_points uniform intervals on [xk_min, 1] and the same
        # number on [1, xk_max]; 1.0 is shared, giving 2 * num_side_points + 1.
        self.npoints = 2 * self.num_side_points + 1
        left_points = np.linspace(self.xk_min, 1.0, cfg.num_side_points + 1)
        right_points = np.linspace(1.0, self.xk_max , cfg.num_side_points + 1)
        xk = torch.tensor(left_points.tolist()[:-1] + [1.0] + right_points.tolist()[1:])
        self.register_buffer('xk', xk)

        # Constant 4x4 matrix that turns [xi^3, xi^2, xi, 1] into the weights
        # of the four per-interval values used in get_func.
        w = torch.tensor([
            [-1.0, 3.0, -3.0, 1.0],
            [3.0, -6.0, 3.0, 0.0],
            [-3.0, 3.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 0.0],
        ]).view(1, 4, 4)
        self.register_buffer('w', w)

        # if cfg.E is not None and cfg.nu is not None:
        #     E = cfg.E
        #     nu = cfg.nu
        #     mu = E / (2 * (1 + nu))
        #     la = E * nu / ((1 + nu) * (1 - 2 * nu))

        #     self.yk_f = nn.Parameter(la * xk - 3 * la + 2 * mu * (xk - 1))
        #     self.yk_g = nn.Parameter(torch.ones_like(xk) * la)
        #     self.yk_h = nn.Parameter(torch.zeros_like(xk))
        # else:

        # One learnable value per knot, initialised as a ramp from -yk_max to yk_max.
        self.yk_f = nn.Parameter(torch.linspace(self.yk_min, self.yk_max, xk.size(0)))
        self.yk_g = nn.Parameter(torch.linspace(self.yk_min, self.yk_max, xk.size(0)))
        self.yk_h = nn.Parameter(torch.linspace(self.yk_min, self.yk_max, xk.size(0)))

    def get_ak(self, yk):
        """Return the second per-interval value (``a``) for each interval.

        The result has ``len(yk) - 1`` entries: a special formula for the
        first interval, then ``yk[i] - yk[i-1] / 6 + yk[i+1] / 6`` for every
        interior knot ``i``.
        """
        ak_1 = 2 / 3 * yk[0] + 2 / 3 * yk[1] - 1 / 3 * yk[2]
        ak_else = yk[1:-1] - 1 / 6 * yk[:-2] + 1 / 6 * yk[2:]
        return torch.cat([ak_1.unsqueeze(0), ak_else], dim=0)

    def get_bk(self, yk):
        """Return the third per-interval value (``b``) for each interval.

        The result has ``len(yk) - 1`` entries:
        ``yk[i] + yk[i-1] / 6 - yk[i+1] / 6`` for every interior knot ``i``,
        then a special formula for the last interval.
        """
        bk_else = yk[1:-1] + 1 / 6 * yk[:-2] - 1 / 6 * yk[2:]
        bk_m = 2 / 3 * yk[-1] + 2 / 3 * yk[-2] - 1 / 3 * yk[-3]
        return torch.cat([bk_else, bk_m.unsqueeze(0)], dim=0)

    def get_func(self, yk, lambd):
        """Evaluate the curve defined by ``yk`` at every entry of ``lambd``.

        Args:
            yk: Per-knot values (same length as ``self.xk``).
            lambd: 2-D tensor of evaluation points (the stacks below use
                ``dim=2``), laid out as ``(batch, #lambda)``.

        Returns:
            Tensor shaped like ``lambd``.
        """
        indices = torch.searchsorted(self.xk, lambd, right=False).view(-1)
        indices[indices < 0] = 0
        # NOTE(review): the upper clamp uses num_side_points - 1 although xk
        # holds 2 * num_side_points + 1 knots, and searchsorted(right=False)
        # returns an insertion index rather than the left knot of the
        # enclosing interval -- confirm both are intended.
        indices[indices > self.num_side_points - 1] = self.num_side_points - 1

        ak = self.get_ak(yk)
        bk = self.get_bk(yk)

        # Four values per evaluation point: interval end values and a/b.
        y_left = yk[indices].view_as(lambd)
        y_right = yk[indices + 1].view_as(lambd)
        a = ak[indices].view_as(lambd)
        b = bk[indices].view_as(lambd)
        temp_right = torch.stack([y_left, a, b, y_right], dim=2)

        # Position of each point relative to its selected interval.
        xi = (lambd - self.xk[indices].view_as(lambd)) / (self.xk[indices + 1].view_as(lambd) - self.xk[indices].view_as(lambd))
        xi_vector = torch.stack([xi**3, xi**2, xi, torch.ones_like(xi)], dim=2) # batch, #lambda, 4

        temp_left = torch.matmul(xi_vector, self.w) # batch, #lambda, 4
        func = (temp_left * temp_right).sum(dim=2) # batch, #lambda

        return func

    def forward(self, F: Tensor) -> Tensor:
        """Return ``U @ diag(new_sigma) @ Vh @ F^T``.

        ``new_sigma`` sums ``f(sigma)``, two ``g`` terms (``g`` of a pairwise
        product times one of the singular values) and an ``h`` term (``h`` of
        the triple product times two singular values).
        """
        U, sigma, Vh = self.svd(F)

        f = self.get_func(self.yk_f, sigma)

        # Pairwise products, in the order (0,1), (1,2), (0,2).
        areas = torch.stack([
            sigma[:, 0] * sigma[:, 1],
            sigma[:, 1] * sigma[:, 2],
            sigma[:, 0] * sigma[:, 2]], dim=1)
        g = self.get_func(self.yk_g, areas)

        g1 = g[:, [0, 0, 2]] * sigma[:, [1, 0, 0]]
        g2 = g[:, [2, 1, 1]] * sigma[:, [2, 2, 1]]

        volume = (sigma[:, 0] * sigma[:, 1] * sigma[:, 2]).unsqueeze(1)
        h = self.get_func(self.yk_h, volume) * sigma[:, [1, 0, 0]] * sigma[:, [2, 2, 1]]

        new_sigma = f + g1 + g2 + h
        P = torch.matmul(torch.matmul(U, torch.diag_embed(new_sigma)), Vh)

        Ft = self.transpose(F)

        cauchy = torch.matmul(P, Ft)
        return cauchy


class MetaPlasticity(Plasticity):
    """Common base for the learned plasticity models in this module.

    Stores the correction scale ``alpha`` and the ``normalize_input`` flag
    from the config and provides ``flatten`` / ``unflatten`` helpers between
    ``(b, dim, dim)`` and ``(b, dim * dim)``. ``self.dim`` is read here and is
    expected to be set by ``Plasticity.__init__``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        self.alpha: float = cfg.alpha

        matrix_axes = dict(d1=self.dim, d2=self.dim)
        self.flatten = Rearrange('b d1 d2 -> b (d1 d2)', **matrix_axes)
        self.unflatten = Rearrange('b (d1 d2) -> b d1 d2', **matrix_axes)

        self.normalize_input: bool = cfg.normalize_input

    def forward(self, F: Tensor) -> Tensor:
        """Subclasses implement the mapping from ``F``; the base always raises."""
        raise NotImplementedError


class PlainMetaPlasticity(MetaPlasticity):
    """MLP on the flattened ``F`` producing an additive correction.

    Returns ``F + alpha * unflatten(MLP(F or F - I))``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        n_entries = self.dim * self.dim
        widths = [n_entries] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(widths[-1], n_entries, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``F`` plus the scaled network correction."""
        if self.normalize_input:
            eye = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
            h = self.flatten(F - eye)
        else:
            h = self.flatten(F)

        for block in self.layers:
            h = block(h)
        h = self.final_layer(h)

        correction = self.alpha * self.unflatten(h)
        return correction + F


class PolarMetaPlasticity(MetaPlasticity):
    """MLP on the symmetric SVD factor of ``F`` producing an additive correction.

    With ``R = U @ Vh`` and ``S = Vh^T @ diag(sigma) @ Vh`` the model returns
    ``F + alpha * R @ sym(MLP(S or S - I))``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        n_entries = self.dim * self.dim
        widths = [n_entries] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(widths[-1], n_entries, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``F`` plus the scaled, rotated, symmetrised network output."""
        U, sigma, Vh = self.svd(F)

        R = torch.matmul(U, Vh)
        scaled_Vh = torch.matmul(self.transpose(Vh), torch.diag_embed(sigma))
        S = torch.matmul(scaled_Vh, Vh)

        if self.normalize_input:
            eye = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
            h = self.flatten(S - eye)
        else:
            h = self.flatten(S)

        for block in self.layers:
            h = block(h)
        h = self.unflatten(self.final_layer(h))

        # Average with the transpose so the network output is symmetric.
        sym = 0.5 * (self.transpose(h) + h)
        correction = self.alpha * torch.matmul(R, sym)
        return correction + F


class InvariantFullMetaPlasticity(MetaPlasticity):
    """MLP on singular values, flattened ``F^T F - I`` and ``det(F) - 1``.

    Returns ``F + alpha * R @ sym(MLP(features))`` with ``R = U @ Vh`` from
    the SVD of ``F``. The input width is derived from ``self.dim``
    (``dim + dim * dim + 1``) so that it always matches the feature vector
    assembled in ``forward``; for ``dim == 3`` this is the former hard-coded
    ``3 + 9 + 1``. LoRA helpers delegate to the functions from ``.loralib``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        self.layers = nn.ModuleList()

        width = self.dim + self.dim * self.dim + 1
        for next_width in cfg.layer_widths:
            self.layers.append(MLPBlock(width, next_width, cfg.no_bias, cfg.norm, cfg.nonlinearity))
            width = next_width

        # Output head: no norm, no nonlinearity.
        self.final_layer = MLPBlock(width, self.dim * self.dim, cfg.no_bias, None, None)

        for m in self.modules():
            init_weight(m)

    def init_lora_layers(self, r: int, lora_alpha: int = 1):
        """Run ``replace_with_linear_lora`` on this module for ``nn.Linear`` layers."""
        replace_with_linear_lora(self, nn.Linear, r, lora_alpha)
        print(f"Initialized LoRA layers in {type(self).__name__} with r={r} and alpha={lora_alpha}.")

    def freeze_all_except_lora(self):
        """Delegate to ``mark_only_lora_as_trainable``."""
        mark_only_lora_as_trainable(self)

    def lora_state_dict(self, bias: str = 'none'):
        """Return the result of the module-level ``lora_state_dict`` helper."""
        return lora_state_dict(self, bias)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``F`` plus the scaled, rotated, symmetrised network output.

        NOTE(review): unlike the sibling models, the features are always
        shifted (``sigma - 1``, ``F^T F - I``, ``det - 1``) regardless of
        ``self.normalize_input``; kept unchanged -- confirm this is intended.
        """
        I = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
        U, sigma, Vh = self.svd(F)
        R = torch.matmul(U, Vh)

        Ft = self.transpose(F)
        FtF = torch.matmul(Ft, F)

        I1 = sigma - 1.0
        I2 = self.flatten(FtF - I)
        I3 = torch.linalg.det(F).unsqueeze(dim=1) - 1.0

        x = torch.cat([I1, I2, I3], dim=1)
        for layer in self.layers:
            x = layer(x)
        x = self.final_layer(x)
        x = self.unflatten(x)
        # Average with the transpose so the network output is symmetric.
        x = 0.5 * (self.transpose(x) + x)
        delta_Fp = self.alpha * torch.matmul(R, x)
        Fp = delta_Fp + F
        return Fp

##################### LXY

class InvariantFullMetaMLElasticity_old(MetaElasticity):
    """Invariant-based elasticity network built from multi-branch blocks.

    Every hidden layer and the output layer wrap an ``MLPBlock`` in a
    ``MultiScaleLoRABlock`` that is called as ``layer(x, invariants)``.

    NOTE(review): ``MultiScaleLoRABlock`` is neither defined nor imported in
    the visible part of this module, so constructing this class presumably
    raises ``NameError`` -- confirm whether the class is still in use.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)
        self.layers = nn.ModuleList()
        # Feature vector: singular values, flattened F^T F, determinant.
        width = self.dim + self.dim * self.dim + 1
        
        for next_width in cfg.layer_widths:
            base_block = MLPBlock(
                in_planes=width,
                out_planes=next_width,
                no_bias=cfg.no_bias,
                norm=cfg.norm,
                nonlinearity=cfg.nonlinearity
            )
            lora_block = MultiScaleLoRABlock(
                base_block=base_block,
                r_elastic=cfg.lora.r_elastic,
                r_plastic=cfg.lora.r_plastic,
                r_visco=cfg.lora.r_visco,
                lora_alpha=cfg.lora.alpha
            )
            self.layers.append(lora_block)
            width = next_width
        
        # Output head: no norm, no nonlinearity.
        base_final = MLPBlock(
            in_planes=width,
            out_planes=self.dim * self.dim,
            no_bias=cfg.no_bias,
            norm=None,
            nonlinearity=None
        )
        self.final_layer = MultiScaleLoRABlock(
            base_block=base_final,
            r_elastic=cfg.lora.r_elastic,
            r_plastic=cfg.lora.r_plastic,
            r_visco=cfg.lora.r_visco,
            lora_alpha=cfg.lora.alpha
        )
        
        for m in self.modules():
            init_weight(m)

        # Runs after init_weight so the LoRA-specific initialisation wins.
        self._init_lora()

    def _init_lora(self):
        """Kaiming-initialise ``lora_A`` and zero ``lora_B`` of each hidden layer's branches.

        Only ``self.layers`` is visited; ``self.final_layer`` is not touched.
        Each branch is expected to expose ``lora_A`` and ``lora_B`` tensors.
        """
    
        
        for layer in self.layers:
            for lora in [layer.elastic_lora, layer.plastic_lora, layer.visco_lora]:
                nn.init.kaiming_uniform_(lora.lora_A, a=math.sqrt(5))
                nn.init.zeros_(lora.lora_B)
                
                
    def init_lora_layers(self, r: int, lora_alpha: int = 1):
        """Run ``replace_with_linear_lora`` on this module for ``nn.Linear`` layers."""
        replace_with_linear_lora(self, nn.Linear, r, lora_alpha)
        print(f"Initialized LoRA layers in {type(self).__name__} with r={r} and alpha={lora_alpha}.")
        

        
    # NOTE(review): this definition is shadowed by the second
    # ``freeze_all_except_lora`` at the end of the class body; only the later
    # one is effective.
    def freeze_all_except_lora(self):
        mark_only_lora_as_trainable(self)
        

    def forward(self, F: Tensor) -> Tensor:
        """Return ``R @ sym(network(features)) @ F^T``.

        The same feature vector is used both as the network input and as the
        ``invariants`` argument passed to every layer.
        """
        
        I = torch.eye(self.dim, dtype=F.dtype, device=F.device, requires_grad=False)
        U, sigma, Vh = self.svd(F)
        R = torch.matmul(U, Vh)
        Ft = self.transpose(F)
        FtF = torch.matmul(Ft, F)
        if self.normalize_input:
            I1 = sigma - 1.0
            I2 = self.flatten(FtF - I)
            I3 = torch.linalg.det(F).unsqueeze(dim=1) - 1.0
        else:
            I1 = sigma
            I2 = self.flatten(FtF)
            I3 = torch.linalg.det(F).unsqueeze(dim=1)
        invariants = torch.cat([I1, I2, I3], dim=1)
        

        x = torch.cat([I1, I2, I3], dim=1)
        for layer in self.layers:
            x = layer(x, invariants)
        x = self.final_layer(x, invariants)
        
        x = self.unflatten(x)
        # Average with the transpose so the network output is symmetric.
        x = 0.5 * (self.transpose(x) + x)
        P = torch.matmul(R, x)
        cauchy = torch.matmul(P, Ft)
        return cauchy

    def load_pretrained(self, pretrained_dict: dict):
        """Load every non-LoRA entry of ``pretrained_dict`` whose key exists in this model."""
        model_dict = self.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and "lora_" not in k
        }
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict, strict=False)
    
    def freeze_all_except_lora(self):
        """Disable gradients for every parameter whose name lacks ``"lora_"``.

        Parameters whose name contains ``"lora_"`` are left as they are.
        """
        for name, param in self.named_parameters():
            if "lora_" not in name:
                param.requires_grad = False
                
                
                
                



##################### LXY

class SplineMetaPlasticity(MetaPlasticity):
    """Plasticity model built from three learnable 1-D piecewise-cubic curves.

    The curves ``f``, ``g`` and ``h`` are evaluated on the singular values of
    ``F``, on pairwise products of singular values and on the product of all
    three. The resulting diagonal is rotated back with ``U`` and ``Vh``,
    scaled by ``alpha`` and added to ``F``. ``forward`` indexes
    ``sigma[:, 0..2]``, i.e. it expects three singular values per sample.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        self.num_side_points: int = cfg.num_side_points
        self.xk_min: float = 0.0
        self.xk_max: float = cfg.xk_max
        self.yk_min: float = -cfg.yk_max
        self.yk_max: float = cfg.yk_max

        # Knots: num_side_points uniform intervals on [xk_min, 1] and the same
        # number on [1, xk_max]; 1.0 is shared, giving 2 * num_side_points + 1.
        self.npoints = 2 * self.num_side_points + 1
        left_points = np.linspace(self.xk_min, 1.0, cfg.num_side_points + 1)
        right_points = np.linspace(1.0, self.xk_max , cfg.num_side_points + 1)
        xk = torch.tensor(left_points.tolist()[:-1] + [1.0] + right_points.tolist()[1:])
        self.register_buffer('xk', xk)

        # Constant 4x4 matrix that turns [xi^3, xi^2, xi, 1] into the weights
        # of the four per-interval values used in get_func.
        w = torch.tensor([
            [-1.0, 3.0, -3.0, 1.0],
            [3.0, -6.0, 3.0, 0.0],
            [-3.0, 3.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 0.0],
        ]).view(1, 4, 4)
        self.register_buffer('w', w)

        # One learnable value per knot, all starting at zero.
        self.yk_f = nn.Parameter(torch.zeros_like(xk))
        self.yk_g = nn.Parameter(torch.zeros_like(xk))
        self.yk_h = nn.Parameter(torch.zeros_like(xk))

    def get_ak(self, yk):
        """Return the second per-interval value (``a``) for each interval.

        The result has ``len(yk) - 1`` entries: a special formula for the
        first interval, then ``yk[i] - yk[i-1] / 6 + yk[i+1] / 6`` for every
        interior knot ``i``.
        """
        ak_1 = 2 / 3 * yk[0] + 2 / 3 * yk[1] - 1 / 3 * yk[2]
        ak_else = yk[1:-1] - 1 / 6 * yk[:-2] + 1 / 6 * yk[2:]
        return torch.cat([ak_1.unsqueeze(0), ak_else], dim=0)

    def get_bk(self, yk):
        """Return the third per-interval value (``b``) for each interval.

        The result has ``len(yk) - 1`` entries:
        ``yk[i] + yk[i-1] / 6 - yk[i+1] / 6`` for every interior knot ``i``,
        then a special formula for the last interval.
        """
        bk_else = yk[1:-1] + 1 / 6 * yk[:-2] - 1 / 6 * yk[2:]
        bk_m = 2 / 3 * yk[-1] + 2 / 3 * yk[-2] - 1 / 3 * yk[-3]
        return torch.cat([bk_else, bk_m.unsqueeze(0)], dim=0)

    def get_func(self, yk, lambd):
        """Evaluate the curve defined by ``yk`` at every entry of ``lambd``.

        Args:
            yk: Per-knot values (same length as ``self.xk``).
            lambd: 2-D tensor of evaluation points (the stacks below use
                ``dim=2``), laid out as ``(batch, #lambda)``.

        Returns:
            Tensor shaped like ``lambd``.
        """
        indices = torch.searchsorted(self.xk, lambd, right=False).view(-1)
        indices[indices < 0] = 0
        # NOTE(review): the upper clamp uses num_side_points - 1 although xk
        # holds 2 * num_side_points + 1 knots, and searchsorted(right=False)
        # returns an insertion index rather than the left knot of the
        # enclosing interval -- confirm both are intended.
        indices[indices > self.num_side_points - 1] = self.num_side_points - 1

        ak = self.get_ak(yk)
        bk = self.get_bk(yk)

        # Four values per evaluation point: interval end values and a/b.
        y_left = yk[indices].view_as(lambd)
        y_right = yk[indices + 1].view_as(lambd)
        a = ak[indices].view_as(lambd)
        b = bk[indices].view_as(lambd)
        temp_right = torch.stack([y_left, a, b, y_right], dim=2)

        # Position of each point relative to its selected interval.
        xi = (lambd - self.xk[indices].view_as(lambd)) / (self.xk[indices + 1].view_as(lambd) - self.xk[indices].view_as(lambd))
        xi_vector = torch.stack([xi**3, xi**2, xi, torch.ones_like(xi)], dim=2) # batch, #lambda, 4

        temp_left = torch.matmul(xi_vector, self.w) # batch, #lambda, 4
        func = (temp_left * temp_right).sum(dim=2) # batch, #lambda

        return func

    def forward(self, F: Tensor) -> Tensor:
        """Return ``F + alpha * U @ diag(new_sigma) @ Vh``.

        ``new_sigma`` sums ``f(sigma)``, two ``g`` terms (``g`` of a pairwise
        product times one of the singular values) and an ``h`` term (``h`` of
        the triple product times two singular values).
        """
        U, sigma, Vh = self.svd(F)

        f = self.get_func(self.yk_f, sigma)

        # Pairwise products, in the order (0,1), (1,2), (0,2).
        areas = torch.stack([
            sigma[:, 0] * sigma[:, 1],
            sigma[:, 1] * sigma[:, 2],
            sigma[:, 0] * sigma[:, 2]], dim=1)
        g = self.get_func(self.yk_g, areas)

        g1 = g[:, [0, 0, 2]] * sigma[:, [1, 0, 0]]
        g2 = g[:, [2, 1, 1]] * sigma[:, [2, 2, 1]]

        volume = (sigma[:, 0] * sigma[:, 1] * sigma[:, 2]).unsqueeze(1)
        h = self.get_func(self.yk_h, volume) * sigma[:, [1, 0, 0]] * sigma[:, [2, 2, 1]]

        new_sigma = f + g1 + g2 + h
        delta_Fp = self.alpha * torch.matmul(torch.matmul(U, torch.diag_embed(new_sigma)), Vh)

        Fp = delta_Fp + F
        return Fp


class SVDMetaPlasticity(MetaPlasticity):
    """MLP on the singular values of ``F`` producing an additive correction.

    With ``U, sigma, Vh = svd(F)`` the model returns
    ``F + alpha * U @ diag(MLP(sigma)) @ Vh``.
    """

    def __init__(self, cfg: DictConfig) -> None:
        super().__init__(cfg)

        # One input per singular value, then the configured hidden widths.
        widths = [self.dim] + list(cfg.layer_widths)

        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MLPBlock(fan_in, fan_out, cfg.no_bias, cfg.norm, cfg.nonlinearity))

        # Output head: no norm, no nonlinearity, one value per singular value.
        self.final_layer = MLPBlock(widths[-1], self.dim, cfg.no_bias, None, None)

        for module in self.modules():
            init_weight(module)

    def forward(self, F: Tensor) -> Tensor:
        """Return ``F`` plus the scaled correction rebuilt from the SVD bases."""
        U, sigma, Vh = self.svd(F)

        h = sigma - 1.0 if self.normalize_input else sigma
        for block in self.layers:
            h = block(h)
        h = self.final_layer(h)

        scaled_U = torch.matmul(U, torch.diag_embed(h))
        correction = self.alpha * torch.matmul(scaled_U, Vh)
        return correction + F
